Feature: Masked Dataset #151
base: main
Conversation
from continuiti.transforms import Transform
from continuiti.operators.shape import OperatorShapes, TensorShape


class OperatorDatasetBase(td.Dataset, ABC):
    """Abstract base class of a dataset for operator training."""

    shapes: OperatorShapes

    def __init__(self, shapes: OperatorShapes, n_observations: int) -> None:
Suggested change:
-    def __init__(self, shapes: OperatorShapes, n_observations: int) -> None:
+    def __init__(self, shapes: OperatorShapes, n_observations: int):
"""Applies class transformations to four tensors. | ||
|
||
Args: | ||
src: |
Suggested change:
-            src:
+            src: List of tuples containing a tensor and a transformation to apply to it.
                continue
            out.append(transformation(src_tensor))
Suggested change:
-                continue
-            out.append(transformation(src_tensor))
+            else:
+                out.append(transformation(src_tensor))
        self, x: torch.Tensor, u: torch.Tensor, y: torch.Tensor, v: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Applies class transformations to four tensors.
        return tensors[0], tensors[1], tensors[2], tensors[3]
Suggested change:
-        return tensors[0], tensors[1], tensors[2], tensors[3]
+        return tuple(tensors)
"""A dataset for operator training containing masks in addition to tensors describing the mapping. | ||
|
||
Data, especially described on unstructured grids, can vary in the number of evaluations or sensors. Even | ||
measurements of phenomena do not always contain the same number of sensors and or evaluations. This dataset is able |
Suggested change:
-    measurements of phenomena do not always contain the same number of sensors and or evaluations. This dataset is able
+    measurements of phenomena do not always contain the same number of sensors and/or evaluations. This dataset is able
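For context, the collation problem that motivates the masks can be sketched in a few lines (hypothetical shapes, not taken from the PR):

```python
import torch

# Two observations with a different number of sensor points
# (hypothetical layout: one coordinate dimension, 10 vs. 7 sensors).
x0 = torch.rand(1, 10)
x1 = torch.rand(1, 7)

# A plain OperatorDataset has to stack observations into one tensor,
# which fails as soon as the sensor counts differ:
try:
    torch.stack([x0, x1])
except RuntimeError as err:
    print(err)  # "stack expects each tensor to be equal size, ..."

# Padding to a common length plus a boolean mask marking the real
# entries is the approach this PR takes.
```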
        assert not any(
            [torch.any(torch.isinf(mi)) for mi in member]
        ), "Expects domain to be truncated in finite space."
Why is this assertion necessary? Someone might come up with a good reason for using infs in the data, do we have to prevent that?
Ah, I see we're using inf for padding. However, does it hurt to have more infs (non-masked) in the dataset?
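A brief sketch of why the assertion matters (illustrative names only): because padding positions are marked with `torch.inf` and the mask is later recovered by comparing against `torch.inf`, an inf that is genuinely part of the data would be indistinguishable from padding and silently masked out.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# One sample contains a "real" inf, the other is simply shorter.
xs = [torch.tensor([[0.0, torch.inf, 2.0]]), torch.tensor([[0.0, 1.0]])]

padded = pad_sequence(
    [x.transpose(0, 1) for x in xs], batch_first=True, padding_value=torch.inf
).transpose(1, 2)

mask = padded != torch.inf
# mask[0, 0, 1] is False: the genuine inf is treated like padding,
# which is why the dataset insists on a finite domain up front.
```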
            padding_value=torch.inf,
        ).transpose(1, 2)
        values_padded = pad_sequence(
            [vi.transpose(0, 1) for vi in values], batch_first=True, padding_value=0
Suggested change:
-            [vi.transpose(0, 1) for vi in values], batch_first=True, padding_value=0
+            [vi.transpose(0, 1) for vi in values],
+            batch_first=True,
+            padding_value=0,
Why do we pad once with inf and once with 0? Seems arbitrary
        mask = member_padded != torch.inf
        member_padded[
            ~mask
        ] = 0  # mask often applied by adding a tensor with -inf values in masked locations (e.g. in scaled dot product).
What is different here compared to l. 287, where we just used 0 for padding?
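To make the two-step scheme concrete, here is a minimal sketch of what the surrounding code appears to do (names and shapes are illustrative): coordinates are padded with inf so the mask can be recovered from them afterwards, the padded positions are then overwritten with 0 so no infs reach the model, and values can be padded with 0 directly because the mask, not the padding value, decides what the model attends to.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

coords = [torch.rand(2, n) for n in (5, 3)]   # (dim, n_points), varying n_points
values = [torch.rand(1, n) for n in (5, 3)]

# 1) Pad coordinates with inf so padding is distinguishable from data.
coords_padded = pad_sequence(
    [c.transpose(0, 1) for c in coords], batch_first=True, padding_value=torch.inf
).transpose(1, 2)

# 2) Recover the mask, then replace the infs with 0 so downstream code
#    (e.g. scaled dot-product attention with an additive -inf mask)
#    never sees infs in the inputs themselves.
mask = coords_padded != torch.inf
coords_padded[~mask] = 0

# 3) Values can be padded with 0 right away; masked positions are
#    ignored by the model anyway.
values_padded = pad_sequence(
    [v.transpose(0, 1) for v in values], batch_first=True, padding_value=0
).transpose(1, 2)
```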
        return sample["x"], sample["u"], sample["y"], sample["v"]
        return tensors[0], tensors[1], tensors[2], tensors[3], ipt_mask, opt_mask
Suggested change:
-        return tensors[0], tensors[1], tensors[2], tensors[3], ipt_mask, opt_mask
+        return *tuple(tensors), ipt_mask, opt_mask
        dataloader = DataLoader(dataset, batch_size=self.batch_size)

        for x, u, y, v, ipt_mask, opt_mask in dataloader:
            assert True
Can we do more here?
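One way to strengthen the test, sketched under the assumption that the loop variables keep the layout implied above (batch-first tensors, boolean masks):

```python
for x, u, y, v, ipt_mask, opt_mask in dataloader:
    # consistent batch dimension across tensors and masks
    assert x.shape[0] == u.shape[0] == y.shape[0] == v.shape[0]
    assert ipt_mask.shape[0] == opt_mask.shape[0] == x.shape[0]

    # masks should be boolean
    assert ipt_mask.dtype == torch.bool
    assert opt_mask.dtype == torch.bool

    # no padding infs should leak into the tensors handed to the model
    assert torch.isfinite(x).all()
    assert torch.isfinite(y).all()
```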
Feature: Masked Dataset

Description
Not all datasets are consistent in the number of sensors and evaluations. Simulations or measurements are not only performed on a multitude of different grids, but may also contain different numbers of samples in both function spaces/sets. To reflect this, the MaskedOperatorDataset class is introduced; it is able to handle datasets with this property. This PR introduces two new operator dataset classes, among them the MaskedOperatorDataset, which is able to process datasets with a varying number of sensors or evaluations.

Which issue does this PR tackle?
The existing OperatorDataset class is only able to handle uniform numbers of evaluations and sensors.

How does it solve the problem?
It introduces the MaskedOperatorDataset class to allow for masked sensors and evaluations.

How are the changes tested?
Notes
The _get_item_ method is specific to the OperatorDataset and MaskedOperatorDataset classes for separation.

Checklist for Contributors
- The branch name follows the feature/title-slug convention.
- The PR title follows the Feature: Title / Bugfix: Title convention.

Checklist for Reviewers: